{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import mnist_inference\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. 定义神经网络结构相关的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Number of examples requested per training step; train() also divides the\n",
    "# training-set size by it to obtain the learning-rate decay interval.\n",
    "BATCH_SIZE = 100 \n",
    "# Initial learning rate passed to tf.train.exponential_decay in train().\n",
    "LEARNING_RATE_BASE = 0.8\n",
    "# Decay rate passed to tf.train.exponential_decay in train().\n",
    "LEARNING_RATE_DECAY = 0.99\n",
    "# Scale passed to tf.contrib.layers.l2_regularizer in train().\n",
    "REGULARIZATION_RATE = 0.0001\n",
    "# Upper bound (exclusive) of the training-loop counter in train().\n",
    "TRAINING_STEPS = 20000\n",
    "# Decay passed to tf.train.ExponentialMovingAverage in train().\n",
    "MOVING_AVERAGE_DECAY = 0.99 \n",
    "# Directory in which train() looks up and writes checkpoints.\n",
    "MODEL_SAVE_PATH = \"MNIST_model/\"\n",
    "# File-name prefix of the checkpoints written by train().\n",
    "MODEL_NAME = \"mnist_model\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 定义训练过程，支持程序关闭后从checkpoint恢复训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(mnist):\n",
    "    \"\"\"Train the MNIST model, resuming from the latest checkpoint when one exists.\n",
    "\n",
    "    Builds the graph (placeholders, forward pass from ``mnist_inference``,\n",
    "    regularised cross-entropy loss, exponentially decayed learning rate and\n",
    "    moving averages of the trainable variables) and then runs the training\n",
    "    loop up to ``TRAINING_STEPS``. Every 1000 loop iterations it prints the\n",
    "    batch loss, writes a checkpoint under ``MODEL_SAVE_PATH`` and removes the\n",
    "    files of the checkpoint written 1000 steps earlier.\n",
    "\n",
    "    Args:\n",
    "        mnist: Dataset object providing ``mnist.train.num_examples`` and\n",
    "            ``mnist.train.next_batch(batch_size)``. Labels are reduced with\n",
    "            ``tf.argmax(y_, 1)``, so they are presumably one-hot encoded.\n",
    "\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    # Iterations between two consecutive log lines / checkpoints.\n",
    "    save_interval = 1000\n",
    "    # One prefix shared by saving and deleting, so both always agree on the path.\n",
    "    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)\n",
    "    # Files removed for an outdated checkpoint.\n",
    "    checkpoint_suffixes = (\".index\", \".data-00000-of-00001\", \".meta\")\n",
    "\n",
    "    # Look up the most recent checkpoint state in the model directory.\n",
    "    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)\n",
    "    # Input and label placeholders.\n",
    "    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')\n",
    "    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')\n",
    "\n",
    "    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)\n",
    "    y = mnist_inference.inference(x, regularizer)\n",
    "    global_step = tf.Variable(0, trainable=False)\n",
    "\n",
    "    # Loss, learning-rate schedule, moving averages and the training op.\n",
    "    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)\n",
    "    variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "    # NOTE(review): the 'losses' collection is presumably filled by\n",
    "    # mnist_inference.inference through the regularizer -- confirm in that module.\n",
    "    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))\n",
    "    learning_rate = tf.train.exponential_decay(\n",
    "        LEARNING_RATE_BASE,\n",
    "        global_step,\n",
    "        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,\n",
    "        staircase=True)\n",
    "    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n",
    "    # Running train_op triggers both the gradient step and the moving-average update.\n",
    "    with tf.control_dependencies([train_step, variables_averages_op]):\n",
    "        train_op = tf.no_op(name='train')\n",
    "\n",
    "    # Saver used both to restore and to write checkpoints.\n",
    "    saver = tf.train.Saver()\n",
    "    with tf.Session() as sess:\n",
    "        saved_step = 0\n",
    "        if ckpt and ckpt.model_checkpoint_path:\n",
    "            print(\"checkpoint存在，直接恢复变量\")\n",
    "            saver.restore(sess, ckpt.model_checkpoint_path)\n",
    "            # Recover the step count from the \"<name>-<step>\" checkpoint file name.\n",
    "            saved_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])\n",
    "            sess.run(global_step.assign(saved_step))\n",
    "        else:\n",
    "            print(\"checkpoint不存在，进行变量初始化\")\n",
    "            tf.global_variables_initializer().run()\n",
    "\n",
    "        for i in range(saved_step, TRAINING_STEPS):\n",
    "            xs, ys = mnist.train.next_batch(BATCH_SIZE)\n",
    "            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})\n",
    "            if i % save_interval == 0:\n",
    "                print(\"After %d training step(s), loss on training batch is %g.\" % (step, loss_value))\n",
    "                saver.save(sess, checkpoint_prefix, global_step=global_step)\n",
    "                # Remove the files of the checkpoint written save_interval steps earlier.\n",
    "                last_step = step - save_interval\n",
    "                if last_step > 0:\n",
    "                    stale_prefix = checkpoint_prefix + \"-\" + str(last_step)\n",
    "                    try:\n",
    "                        for suffix in checkpoint_suffixes:\n",
    "                            os.remove(stale_prefix + suffix)\n",
    "                    except OSError:\n",
    "                        # Only file-system failures are reported here; anything else\n",
    "                        # (e.g. KeyboardInterrupt) must propagate to the caller.\n",
    "                        print(\"删除数据异常\")\n",
    "                    else:\n",
    "                        print(\"成功删除：\", stale_prefix + \".*\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 主程序入口。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../../datasets/MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "checkpoint不存在，进行变量初始化\n",
      "After 1 training step(s), loss on training batch is 2.82538.\n",
      "After 1001 training step(s), loss on training batch is 0.24735.\n",
      "成功删除： MNIST_model/mnist_model-1.*\n",
      "After 2001 training step(s), loss on training batch is 0.235086.\n",
      "成功删除： MNIST_model/mnist_model-1001.*\n",
      "After 3001 training step(s), loss on training batch is 0.187732.\n",
      "成功删除： MNIST_model/mnist_model-2001.*\n",
      "After 4001 training step(s), loss on training batch is 0.120792.\n",
      "成功删除： MNIST_model/mnist_model-3001.*\n",
      "After 5001 training step(s), loss on training batch is 0.106697.\n",
      "成功删除： MNIST_model/mnist_model-4001.*\n",
      "After 6001 training step(s), loss on training batch is 0.129981.\n",
      "成功删除： MNIST_model/mnist_model-5001.*\n",
      "After 7001 training step(s), loss on training batch is 0.0900439.\n",
      "成功删除： MNIST_model/mnist_model-6001.*\n",
      "After 8001 training step(s), loss on training batch is 0.0818427.\n",
      "成功删除： MNIST_model/mnist_model-7001.*\n",
      "After 9001 training step(s), loss on training batch is 0.0740065.\n",
      "成功删除： MNIST_model/mnist_model-8001.*\n",
      "After 10001 training step(s), loss on training batch is 0.0694076.\n",
      "成功删除： MNIST_model/mnist_model-9001.*\n",
      "After 11001 training step(s), loss on training batch is 0.0664568.\n",
      "成功删除： MNIST_model/mnist_model-10001.*\n",
      "After 12001 training step(s), loss on training batch is 0.0625552.\n",
      "成功删除： MNIST_model/mnist_model-11001.*\n",
      "After 13001 training step(s), loss on training batch is 0.0590671.\n",
      "成功删除： MNIST_model/mnist_model-12001.*\n",
      "After 14001 training step(s), loss on training batch is 0.0569511.\n",
      "成功删除： MNIST_model/mnist_model-13001.*\n",
      "After 15001 training step(s), loss on training batch is 0.0489733.\n",
      "成功删除： MNIST_model/mnist_model-14001.*\n",
      "After 16001 training step(s), loss on training batch is 0.0509863.\n",
      "成功删除： MNIST_model/mnist_model-15001.*\n",
      "After 17001 training step(s), loss on training batch is 0.0452694.\n",
      "成功删除： MNIST_model/mnist_model-16001.*\n",
      "After 18001 training step(s), loss on training batch is 0.0445102.\n",
      "成功删除： MNIST_model/mnist_model-17001.*\n",
      "After 19001 training step(s), loss on training batch is 0.0398916.\n",
      "成功删除： MNIST_model/mnist_model-18001.*\n"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    \"\"\"Load the MNIST data set with one-hot labels and run training on it.\n",
    "\n",
    "    Args:\n",
    "        argv: Not used by this function.\n",
    "    \"\"\"\n",
    "    dataset = input_data.read_data_sets(\"../../datasets/MNIST_data\", one_hot=True)\n",
    "    train(dataset)\n",
    "\n",
    "\n",
    "# Entry point: start training when the cell/script is executed directly.\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
